{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2sLabK5fIElb"
      },
      "source": [
        "##### Copyright 2021 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DwsgIGum5PF6"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import numpy as np\n",
        "from scipy.io import wavfile\n",
        "from scipy.signal import resample\n",
        "import tensorflow.compat.v1 as tf\n",
        "from tensorflow.io import gfile\n",
        "from IPython.display import Audio, display\n",
        "from ipywidgets import widgets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QynSWe_pKM6s"
      },
      "outputs": [],
      "source": [
        "# @title Mount Google Drive\n",
        "# Mounts Google Drive into the runtime so the audio folders can be accessed\n",
        "# as regular paths. `google.colab` is imported here, next to its only use.\n",
        "from google.colab import drive\n",
        "# Directory at which Drive is mounted by the call below.\n",
        "ROOT_DIR = '/content/gdrive'\n",
        "drive.mount(ROOT_DIR, force_remount=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "hag_CDbLJ4nT"
      },
      "outputs": [],
      "source": [
        "# @title Helper function for playing audio.\n",
        "def PlaySound(samples, sample_rate=16000):\n",
        "  \"\"\"Shows an inline audio player for a waveform.\n",
        "\n",
        "  Args:\n",
        "    samples: Audio data handed to `IPython.display.Audio`.\n",
        "    sample_rate: Value passed to the player as its `rate`.\n",
        "  \"\"\"\n",
        "  player = Audio(samples, rate=sample_rate)\n",
        "  output_area = widgets.Output()\n",
        "  # Render the player inside its own output widget, then show the widget.\n",
        "  with output_area:\n",
        "    display(player)\n",
        "  display(output_area)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "lX8a_0IqNrWE"
      },
      "outputs": [],
      "source": [
        "# @title Helper class for separation model inference.\n",
        "class SeparationModel(object):\n",
        "  \"\"\"Tensorflow audio separation model.\n",
        "\n",
        "  Imports a metagraph into a private graph, restores the weights into a\n",
        "  dedicated session and exposes `separate` for inference. Call `close()`\n",
        "  (or use the instance in a `with` statement) to release the session.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self,\n",
        "               checkpoint_path,\n",
        "               metagraph_path):\n",
        "    \"\"\"Builds the graph, restores the weights and looks up the I/O tensors.\n",
        "\n",
        "    Args:\n",
        "      checkpoint_path: Path handed to the saver's `restore`.\n",
        "      metagraph_path: Path handed to `tf.train.import_meta_graph`.\n",
        "    \"\"\"\n",
        "    self.graph = tf.Graph()\n",
        "    self.sess = tf.Session(graph=self.graph)\n",
        "    try:\n",
        "      with self.graph.as_default():\n",
        "        new_saver = tf.train.import_meta_graph(metagraph_path)\n",
        "        new_saver.restore(self.sess, checkpoint_path)\n",
        "      self.input_placeholder = self.graph.get_tensor_by_name(\n",
        "          'input_audio/receiver_audio:0')\n",
        "      self.output_tensor = self.graph.get_tensor_by_name(\n",
        "          'denoised_waveforms:0')\n",
        "    except Exception:\n",
        "      # Do not leak the session when the model cannot be loaded.\n",
        "      self.sess.close()\n",
        "      raise\n",
        "\n",
        "  def close(self):\n",
        "    \"\"\"Closes the underlying session.\"\"\"\n",
        "    self.sess.close()\n",
        "\n",
        "  def __enter__(self):\n",
        "    return self\n",
        "\n",
        "  def __exit__(self, exc_type, exc_value, traceback):\n",
        "    self.close()\n",
        "\n",
        "  def separate(self, mixture_waveform):\n",
        "    \"\"\"Separates a mixture waveform into sources.\n",
        "\n",
        "    Args:\n",
        "      mixture_waveform: numpy.ndarray of shape (num_mics, num_samples).\n",
        "\n",
        "    Returns:\n",
        "      numpy.ndarray of separated waveforms of shape (num_sources, num_samples).\n",
        "    \"\"\"\n",
        "    # The graph input carries a leading batch axis; run a batch of one and\n",
        "    # strip that axis from the result again.\n",
        "    mixture_waveform_input = np.expand_dims(mixture_waveform, 0)\n",
        "    feed_dict = {self.input_placeholder: mixture_waveform_input}\n",
        "\n",
        "    separated_waveforms = self.sess.run(self.output_tensor, feed_dict=feed_dict)\n",
        "    return separated_waveforms[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "OE6p4LgrJJD3"
      },
      "outputs": [],
      "source": [
        "# @title Download the pre-trained speech enhancement model files.\n",
        "# A `%%shell` cell magic is only recognised on the first line of a cell,\n",
        "# which is taken by the title above, so the `!` shell escape is used instead.\n",
        "!gsutil cp -r gs://gresearch/cochlear_implant/speech_enhancement_model /tmp/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5dwPjD3h5TWC"
      },
      "outputs": [],
      "source": [
        "# @title Load speech enhancement model.\n",
        "# Directory that the checkpoint and metagraph paths below are built from.\n",
        "MODEL_PATH = '/tmp/speech_enhancement_model'\n",
        "\n",
        "# NOTE(review): `checkpoint` is handed to the saver's restore() inside\n",
        "# SeparationModel -- confirm that 'checkpoint' is the checkpoint prefix of\n",
        "# the model files.\n",
        "checkpoint = os.path.join(MODEL_PATH, 'checkpoint')\n",
        "metagraph = os.path.join(MODEL_PATH, 'inference.meta')\n",
        "# Imports the graph and restores the weights.\n",
        "model = SeparationModel(checkpoint, metagraph)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_dgSWLP3JlHJ"
      },
      "outputs": [],
      "source": [
        "# @title Get some wav paths.\n",
        "# Anchor the audio folder at the Drive mount point (ROOT_DIR) so the lookup\n",
        "# does not depend on the kernel's current working directory.\n",
        "PATH_AUDIO = os.path.join(ROOT_DIR, 'My Drive', 'cihack_audio')  # E.g. /content/gdrive/My Drive/cihack_audio\n",
        "# Enhanced files are written to a sibling folder of the input folder.\n",
        "PATH_ENHANCED = PATH_AUDIO + '_enhanced'\n",
        "\n",
        "audio_clip_matcher = '*.wav'  #@param\n",
        "# Sorted so the files are processed in the same order on every run.\n",
        "wavs = sorted(gfile.glob(os.path.join(PATH_AUDIO, audio_clip_matcher)))\n",
        "if not wavs:\n",
        "  # Make an empty match visible instead of silently skipping all processing.\n",
        "  print('No files matching %r found in %s' % (audio_clip_matcher, PATH_AUDIO))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "qZ1iB6qCbSAb"
      },
      "outputs": [],
      "source": [
        "#@title File read/write functions.\n",
        "def write_wav(filename, waveform, sample_rate=16000):\n",
        "  \"\"\"Write an audio waveform (float numpy array) as a 16-bit .wav file.\n",
        "\n",
        "  Args:\n",
        "    filename: Path of the .wav file to write.\n",
        "    waveform: Float numpy array; values are scaled by 2**15, clipped to the\n",
        "      int16 range and rounded.\n",
        "    sample_rate: Int, sample rate stored in the file.\n",
        "  \"\"\"\n",
        "  scaled = waveform * 2**15\n",
        "  clipped = np.clip(scaled, -32768, 32767)\n",
        "  pcm = np.round(clipped).astype(np.int16)\n",
        "  wavfile.write(filename, sample_rate, pcm)\n",
        "\n",
        "def read_wav(wav_path, sample_rate=16000, channel=None):\n",
        "  \"\"\"Read a wav file as numpy array.\n",
        "\n",
        "  Integer PCM data is scaled according to the integer type it was read as\n",
        "  (unsigned 8-bit data is re-centered first); floating point data is only\n",
        "  converted to float32.\n",
        "\n",
        "  Args:\n",
        "    wav_path: String, path to .wav file.\n",
        "    sample_rate: Int, sample rate for audio to be converted to.\n",
        "    channel: Int, option to select a particular channel for stereo audio.\n",
        "\n",
        "  Returns:\n",
        "    Audio as float numpy array.\n",
        "  \"\"\"\n",
        "  sr_read, x = wavfile.read(wav_path)\n",
        "  if x.dtype == np.uint8:\n",
        "    # 8-bit wav data is unsigned with its zero level at 128.\n",
        "    x = (x.astype(np.float32) - 128.0) / 128.0\n",
        "  elif np.issubdtype(x.dtype, np.integer):\n",
        "    # Divide by the full-scale magnitude of the integer type\n",
        "    # (2**15 for int16 data, 2**31 for int32 data).\n",
        "    x = x.astype(np.float32) / (float(np.iinfo(x.dtype).max) + 1.0)\n",
        "  else:\n",
        "    # Floating point wav data is not rescaled.\n",
        "    x = x.astype(np.float32)\n",
        "\n",
        "  if sr_read != sample_rate:\n",
        "    # Resample along the first (sample) axis to the requested rate.\n",
        "    x = resample(x, int(round((float(sample_rate) / sr_read) * len(x))))\n",
        "  if x.ndim \u003e 1 and channel is not None:\n",
        "    return x[:, channel]\n",
        "  return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7m_ggHiOJ0Ey"
      },
      "outputs": [],
      "source": [
        "# @title Enhance some wavs and play audio.\n",
        "# Every output file goes to the same folder, so create it once up front.\n",
        "gfile.makedirs(PATH_ENHANCED)\n",
        "\n",
        "for wav in wavs:\n",
        "  print('wav path:', wav)\n",
        "  print('Input')\n",
        "  # Pick the first channel so multi-channel files still give the single 1-D\n",
        "  # track used below; mono files are returned unchanged.\n",
        "  receiver_audio = read_wav(wav, channel=0)\n",
        "\n",
        "  PlaySound(receiver_audio, 16000)\n",
        "\n",
        "  # Add a leading axis: the model takes input of shape (num_mics, num_samples).\n",
        "  enhanced = model.separate(receiver_audio[np.newaxis])\n",
        "  output_path = os.path.join(PATH_ENHANCED, os.path.basename(wav))\n",
        "  # Transposed to (num_samples, num_sources) before writing.\n",
        "  write_wav(output_path, enhanced.T, 16000)\n",
        "\n",
        "  print('Speech estimate')\n",
        "  PlaySound(enhanced[0], 16000)\n",
        "  print('Noise estimate')\n",
        "  PlaySound(enhanced[1], 16000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7YMMI2TfQi9J"
      },
      "outputs": [],
      "source": [
        ""
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "speech_enhancement_inference.ipynb",
      "provenance": [
        {
          "file_id": "1aWvNK_mN8Di8OPWa9DBGClYE7La4yBD3",
          "timestamp": 1614196108373
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
